from ray.rllib.agents.ppo import PPOTrainer, PPOTorchPolicy

from grl.envs.mujoco_multi_agent_env import MujocoMultiAgentEnv 
from grl.envs.mujoco_adv_env import MujocoAdvEnv
from grl.rl_apps.scenarios.catalog import scenario_catalog
from grl.rl_apps.scenarios.catalog.common import default_if_creating_ray_head
from grl.rl_apps.scenarios.psro_scenario import PSROScenario
from grl.rl_apps.scenarios.stopping_conditions import *
from grl.rl_apps.scenarios.trainer_configs.mujoco_psro_configs import *
from grl.rl_apps.scenarios.trainer_configs.mujoco_psro_configs import *

def _mujoco_psro_stopping_condition():
    """Build the stopping condition used for each PSRO best-response run.

    Returns a new ``TimeStepsSingleBRRewardPlateauStoppingCondition`` that
    tracks the ``"best_response"`` policy. It is configured as follows:

    * no plateau check before 100k timesteps;
    * a plateau check every 50k timesteps thereafter;
    * a minimum reward improvement of 0.02 to avoid being treated as
      plateaued;
    * a hard cap of 1M training timesteps.

    A zero-argument factory (rather than a shared instance) is passed to the
    scenario, so every call yields a freshly constructed condition object.
    """
    return TimeStepsSingleBRRewardPlateauStoppingCondition(
        br_policy_id="best_response",
        dont_check_plateau_before_n_steps=int(100e3),
        check_plateau_every_n_steps=int(50e3),
        minimum_reward_improvement_otherwise_plateaued=0.02,
        max_train_steps=int(1e6),
    )


# PSRO on the multi-agent MuJoCo environment, using PPO (Torch policies) for
# the metanash, best-response, and evaluation roles.
mujoco_psro_ppo = PSROScenario(
    # Plain string: the previous f-prefix had no placeholders (ruff F541).
    name="mujoco_multi_agent_psro",
    # Defaults presumably apply only when this process starts the Ray head
    # (per the helper's name) -- confirm in default_if_creating_ray_head.
    ray_cluster_cpus=default_if_creating_ray_head(default=8),
    ray_cluster_gpus=default_if_creating_ray_head(default=0),
    ray_object_store_memory_cap_gigabytes=1,
    env_class=MujocoMultiAgentEnv,
    env_config={
        # MuJoCo actions are continuous, so no discrete valid-action mask is
        # appended to observations.
        "append_valid_actions_mask_to_obs": False,
        "continuous_action_space": True,
    },
    # 0.0 => the metanash distribution is used as-is, with no uniform mixing.
    mix_metanash_with_uniform_dist_coeff=0.0,
    allow_stochastic_best_responses=False,
    trainer_class=PPOTrainer,
    policy_classes={
        "metanash": PPOTorchPolicy,
        "best_response": PPOTorchPolicy,
        "eval": PPOTorchPolicy,
    },
    num_eval_workers=8,
    games_per_payoff_eval=1,
    # P2SRO is disabled, so its payoff-table averaging/sync settings are left
    # as None (presumably ignored in this mode -- confirm in PSROScenario).
    p2sro=False,
    p2sro_payoff_table_exponential_avg_coeff=None,
    p2sro_sync_with_payoff_table_every_n_episodes=None,
    single_agent_symmetric_game=False,
    get_trainer_config=mujoco_psro_ppo_params,
    psro_get_stopping_condition=_mujoco_psro_stopping_condition,
    # This is a MuJoCo env rather than an OpenSpiel game, so OpenSpiel-based
    # exploitability computation is turned off.
    calc_exploitability_for_openspiel_env=False,
)
scenario_catalog.add(mujoco_psro_ppo)


